#!/usr/bin/env python

# compute statistics about the quality of the geometric segmentation
# at the level of the given OCR element

from __future__ import print_function
import argparse
import re

from lxml import html

# general utilities


def get_text(node):
    """Return the text found below *node* as a single string.

    All text nodes selected by the relative XPath ``.//text()`` are
    concatenated in the order returned, and every run of whitespace is
    collapsed into one space.
    """
    pieces = node.xpath(".//text()")
    joined = ''.join(pieces)
    return re.sub(r'\s+', ' ', joined)


def get_prop(node, name):
    """Look up one property in the ``title`` attribute of *node*.

    The title is treated as a ';'-separated list of entries, each of the
    form ``key args``.

    Args:
        node: object whose ``get('title')`` returns the title string
            (or a falsy value when there is none).
        name: the property key to look for.

    Returns:
        The argument string of the first entry whose key equals *name*,
        ``''`` if that entry has a key but no arguments, or ``None`` if
        the title is missing/empty or contains no such key.
    """
    title = node.get('title')
    if not title:
        return None
    for prop in title.split(';'):
        parts = prop.split(None, 1)
        if not parts:
            # Empty segment, e.g. produced by a trailing ';' -- skip it
            # instead of failing on the two-element unpacking.
            continue
        if parts[0] == name:
            # A bare key without arguments yields an empty string.
            return parts[1] if len(parts) > 1 else ''
    return None


def get_bbox(node):
    """Return the ``bbox`` property of *node* as a tuple of ints.

    Returns ``None`` when the node has no (or an empty) ``bbox`` property.
    """
    raw = get_prop(node, 'bbox')
    if raw:
        return tuple(int(token) for token in raw.split())
    return None


# rectangle properties


def intersect(u, v):
    """Return the intersection of two rectangles as a 4-tuple.

    Components 0 and 1 of the result are the larger of the two inputs'
    components, components 2 and 3 the smaller.  For rectangles that do
    not overlap the result is "inverted" (a lower bound exceeds its upper
    bound); no check is made here.
    """
    low0 = max(u[0], v[0])
    low1 = max(u[1], v[1])
    high0 = min(u[2], v[2])
    high1 = min(u[3], v[3])
    return (low0, low1, high0, high1)


def area(u):
    """Return the area of rectangle *u*.

    Each extent (component 2 minus 0, component 3 minus 1) is clamped at
    zero first, so an inverted rectangle has area 0 rather than a
    negative or spuriously positive value.
    """
    extent0 = max(0, u[2] - u[0])
    extent1 = max(0, u[3] - u[1])
    return extent0 * extent1


def overlaps(u, v):
    """Predicate: True when rectangles *u* and *v* share a positive area."""
    common = intersect(u, v)
    return area(common) > 0


def relative_overlap(u, v):
    """Return the overlap of *u* and *v* relative to the larger rectangle.

    The result is the area of the intersection divided by the larger of
    the two rectangle areas, as a float.

    Two rectangles that both have zero area cannot overlap, so 0.0 is
    returned for them instead of dividing by zero.
    """
    larger = max(area(u), area(v))
    if larger == 0:
        # Both rectangles are degenerate; avoid ZeroDivisionError.
        return 0.0
    shared = area(intersect(u, v))
    return float(shared) / larger


################################################################
# main program
################################################################

# argument parsing

parser = argparse.ArgumentParser(
    description=("Compute statistics about the quality of the geometric "
                 "segmentation at the level of the given OCR element"),
    epilog=("The output is a 4-tuple (multiple,missing,error,count) "
            "for the truth compared with the actual and then again "
            "another 4-tuple in the other direction")
)
parser.add_argument(
    "truth", help="hOCR file with ground truth", type=argparse.FileType('r'))
parser.add_argument(
    "actual",
    help="hOCR file from the actual recognition",
    type=argparse.FileType('r'))
# value of the class attribute of the elements whose boxes are compared
parser.add_argument(
    "-e",
    "--element",
    default="ocr_line",
    help="OCR element to look at, default: %(default)s")
# relative overlap above which two boxes count as overlapping significantly
parser.add_argument(
    "-o",
    "--significant_overlap",
    type=float,
    default=0.1,
    help="default: %(default)s")
# relative overlap above which two boxes count as matching each other
parser.add_argument(
    "-c",
    "--close_match",
    type=float,
    default=0.9,
    help="default: %(default)s")
args = parser.parse_args()

# read the hOCR files

truth_doc = html.parse(args.truth)
actual_doc = html.parse(args.actual)
truth_pages = truth_doc.xpath("//*[@class='ocr_page']")
actual_pages = actual_doc.xpath("//*[@class='ocr_page']")
# Validate with a real exception rather than an assert: asserts are removed
# under "python -O", and zip() below would then silently drop the pages
# that have no counterpart in the other file.
if len(truth_pages) != len(actual_pages):
    raise ValueError(
        "Page count mismatch: truth has %d page(s), actual has %d"
        % (len(truth_pages), len(actual_pages)))
pages = zip(truth_pages, actual_pages)

# compute statistics


def boxstats(truths, actuals):
    """Compare every box in *truths* against the boxes in *actuals*.

    Thresholds are read from the command-line options
    ``args.significant_overlap`` and ``args.close_match``.

    Returns:
        A 4-tuple ``(multiple, missing, error, count)``:
        multiple -- reference boxes with more than one significantly
                    overlapping candidate
        missing  -- reference boxes without any close match
        error    -- sum of ``1.0 - overlap`` over the boxes that have
                    exactly one close match
        count    -- number of boxes that have exactly one close match

    Raises:
        AttributeError: if a reference box has more than one close match.
    """
    multiple = 0
    missing = 0
    error = 0
    count = 0
    for ref in truths:
        # relative overlap with every candidate that intersects at all
        scores = [
            relative_overlap(ref, cand)
            for cand in actuals
            if overlaps(cand, ref)
        ]
        significant = sum(
            1 for score in scores if score > args.significant_overlap)
        if significant > 1:
            multiple += 1
        close = [score for score in scores if score > args.close_match]
        if len(close) > 1:
            raise AttributeError(
                "Multiple close matches: your segmentation files are bad")
        if not close:
            missing += 1
            continue
        error += 1.0 - close[0]
        count += 1
    return multiple, missing, error, count


def check_bad_partition(boxes):
    """Return 1 if any two boxes overlap significantly, otherwise 0.

    Every unordered pair of boxes is compared once; a pair counts as bad
    when its relative overlap exceeds ``args.significant_overlap``.
    """
    for position, first in enumerate(boxes):
        # only look at later boxes so each pair is tested a single time
        for second in boxes[position + 1:]:
            if relative_overlap(first, second) > args.significant_overlap:
                return 1
    return 0


# For every pair of corresponding pages: collect the boxes of the requested
# element class, reject inputs whose boxes overlap each other significantly,
# then print the statistics in both directions.
for truth, actual in pages:
    # The leading "." keeps the search inside the current page element.
    # A query starting with "//" is absolute and would return the matching
    # elements of the whole document for every page.
    element_query = ".//*[@class='%s']" % args.element
    tobjs = truth.xpath(element_query)
    aobjs = actual.xpath(element_query)
    tboxes = [get_bbox(n) for n in tobjs]
    if check_bad_partition(tboxes):
        raise ValueError("Ground truth data is not an acceptable segmentation")
    aboxes = [get_bbox(n) for n in aobjs]
    if check_bad_partition(aboxes):
        raise ValueError("Actual data is not an acceptable segmentation")
    print(boxstats(tboxes, aboxes), boxstats(aboxes, tboxes))
